{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import random\n",
    "import math\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">> Converting image 1000/1000 shard 4\n",
      ">> Converting image 500/500 shard 4\n"
     ]
    }
   ],
   "source": [
    "# Number of shuffled images held out for the 'test' split\n",
    "_NUM_TEST = 500\n",
    "# Seed passed to random.seed() before the file list is shuffled\n",
    "_RANDOM_SEED = 0\n",
    "# Number of TFRecord shard files written for each split\n",
    "_NUM_SHARDS = 5\n",
    "# Directory that is scanned for class sub-folders and that receives the output files\n",
    "DATASET_DIR = \"D:/Tensorflow/slim/images/\"\n",
    "# Path of the labels file (default `filename` argument of write_label_file)\n",
    "LABELS_FILENAME = \"D:/Tensorflow/slim/images/labels.txt\"\n",
    "\n",
    "def _get_dataset_filename(dataset_dir, split_name, shard_id):\n",
    "    \"\"\"Build the path of one TFRecord shard file under `dataset_dir`.\n",
    "\n",
    "    The file is named 'image_<split_name>_<shard_id>-of-<_NUM_SHARDS>.tfrecord',\n",
    "    with both numbers zero-padded to five digits.\n",
    "    \"\"\"\n",
    "    name_fields = (split_name, shard_id, _NUM_SHARDS)\n",
    "    shard_name = 'image_%s_%05d-of-%05d.tfrecord' % name_fields\n",
    "    return os.path.join(dataset_dir, shard_name)\n",
    "\n",
    "def _dataset_exists(dataset_dir):\n",
    "    \"\"\"Return True only if every expected TFRecord shard file already exists.\n",
    "\n",
    "    Looks for all `_NUM_SHARDS` shard files of both the 'train' and the 'test'\n",
    "    split inside `dataset_dir` and returns False as soon as one is missing.\n",
    "    \"\"\"\n",
    "    for split_name in ['train', 'test']:\n",
    "        for shard_id in range(_NUM_SHARDS):\n",
    "            shard_path = _get_dataset_filename(dataset_dir, split_name, shard_id)\n",
    "            # The check has to run once per shard. Placed after the inner loop it\n",
    "            # would only ever test the last shard of each split, so a partially\n",
    "            # written dataset would be reported as complete.\n",
    "            if not tf.gfile.Exists(shard_path):\n",
    "                return False\n",
    "    return True\n",
    "\n",
    "#获取所有文件以及分类\n",
    "def _get_filenames_and_classes(dataset_dir):\n",
    "    #数据目录\n",
    "    directories = []\n",
    "    #分类名称\n",
    "    class_names = []\n",
    "    for filename in os.listdir(dataset_dir):\n",
    "        #合并文件路径\n",
    "        path = os.path.join(dataset_dir, filename)\n",
    "        #判断该路径是否为目录\n",
    "        if os.path.isdir(path):\n",
    "            #加入数据目录\n",
    "            directories.append(path)\n",
    "            #加入类别名称\n",
    "            class_names.append(filename)\n",
    "\n",
    "    photo_filenames = []\n",
    "    #循环每个分类的文件夹\n",
    "    for directory in directories:\n",
    "        for filename in os.listdir(directory):\n",
    "            path = os.path.join(directory, filename)\n",
    "            #把图片加入图片列表\n",
    "            photo_filenames.append(path)\n",
    "\n",
    "    return photo_filenames, class_names\n",
    "\n",
    "def int64_feature(values):\n",
    "    \"\"\"Wrap `values` in an int64-list `tf.train.Feature`.\n",
    "\n",
    "    A tuple or list is passed through as given; any other value is first\n",
    "    placed in a one-element list.\n",
    "    \"\"\"\n",
    "    value_list = values if isinstance(values, (tuple, list)) else [values]\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))\n",
    "\n",
    "def bytes_feature(values):\n",
    "    \"\"\"Wrap the single value `values` in a bytes-list `tf.train.Feature`.\"\"\"\n",
    "    wrapped = tf.train.BytesList(value=[values])\n",
    "    return tf.train.Feature(bytes_list=wrapped)\n",
    "\n",
    "def image_to_tfexample(image_data, image_format, class_id):\n",
    "    \"\"\"Build a `tf.train.Example` for one image.\n",
    "\n",
    "    The example has three features: 'image/encoded' (from `image_data`),\n",
    "    'image/format' (from `image_format`) and 'image/class/label' (from\n",
    "    `class_id`).\n",
    "    \"\"\"\n",
    "    feature_map = {\n",
    "        'image/encoded': bytes_feature(image_data),\n",
    "        'image/format': bytes_feature(image_format),\n",
    "        'image/class/label': int64_feature(class_id),\n",
    "    }\n",
    "    return tf.train.Example(features=tf.train.Features(feature=feature_map))\n",
    "\n",
    "def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):\n",
    "    \"\"\"Write the label map as text, one '<label>:<class_name>' line per entry.\n",
    "\n",
    "    The output path is `os.path.join(dataset_dir, filename)`.\n",
    "    \"\"\"\n",
    "    target_path = os.path.join(dataset_dir, filename)\n",
    "    with tf.gfile.Open(target_path, 'w') as label_file:\n",
    "        for label, class_name in labels_to_class_names.items():\n",
    "            label_file.write('%d:%s\\n' % (label, class_name))\n",
    "\n",
    "def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):\n",
    "    \"\"\"Write the images listed in `filenames` into `_NUM_SHARDS` TFRecord files.\n",
    "\n",
    "    Args:\n",
    "        split_name: 'train' or 'test'; becomes part of each output file name.\n",
    "        filenames: list of image file paths. The class of an image is the\n",
    "            name of the directory that contains it.\n",
    "        class_names_to_ids: dict mapping a class name to its integer label.\n",
    "        dataset_dir: directory the shard files are written to.\n",
    "\n",
    "    Images that cannot be read are reported on stdout and skipped.\n",
    "    \"\"\"\n",
    "    assert split_name in ['train', 'test']\n",
    "    # Round up: with floor division the last len(filenames) % _NUM_SHARDS\n",
    "    # images would fall outside every shard and never be written.\n",
    "    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))\n",
    "    # No graph or session is needed: the files are only read and serialized.\n",
    "    for shard_id in range(_NUM_SHARDS):\n",
    "        output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)\n",
    "        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:\n",
    "            # Half-open index range [start_ndx, end_ndx) covered by this shard\n",
    "            start_ndx = shard_id * num_per_shard\n",
    "            end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))\n",
    "            for i in range(start_ndx, end_ndx):\n",
    "                try:\n",
    "                    sys.stdout.write('\\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))\n",
    "                    sys.stdout.flush()\n",
    "                    # Image bytes must be read in binary mode; the context\n",
    "                    # manager closes the handle after each file.\n",
    "                    with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:\n",
    "                        image_data = image_file.read()\n",
    "                    # The parent directory name is the class name\n",
    "                    class_name = os.path.basename(os.path.dirname(filenames[i]))\n",
    "                    class_id = class_names_to_ids[class_name]\n",
    "                    example = image_to_tfexample(image_data, b'jpg', class_id)\n",
    "                    tfrecord_writer.write(example.SerializeToString())\n",
    "                # tf.gfile reports file-system failures as tf.errors.OpError\n",
    "                except (IOError, tf.errors.OpError) as e:\n",
    "                    print(\"Could not read:\",filenames[i])\n",
    "                    print(\"Error:\",e)\n",
    "                    print(\"Skip it\\n\")\n",
    "\n",
    "    sys.stdout.write('\\n')\n",
    "    sys.stdout.flush()\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Skip the conversion when the TFRecord files are already present\n",
    "    if _dataset_exists(DATASET_DIR):\n",
    "        print('tfrecord文件已存在')\n",
    "    else:\n",
    "        # Collect every image path and the list of class names\n",
    "        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)\n",
    "        # Map each class name to an integer id, e.g.\n",
    "        # {'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}\n",
    "        class_names_to_ids = dict(zip(class_names, range(len(class_names))))\n",
    "\n",
    "        # Shuffle with a fixed seed, then use the first _NUM_TEST images as the\n",
    "        # test split and the remainder as the training split\n",
    "        random.seed(_RANDOM_SEED)\n",
    "        random.shuffle(photo_filenames)\n",
    "        training_filenames = photo_filenames[_NUM_TEST:]\n",
    "        testing_filenames = photo_filenames[:_NUM_TEST]\n",
    "\n",
    "        # Convert both splits to TFRecord files\n",
    "        _convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)\n",
    "        _convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)\n",
    "\n",
    "        # Write the labels file (integer id -> class name)\n",
    "        labels_to_class_names = dict(zip(range(len(class_names)), class_names))\n",
    "        write_label_file(labels_to_class_names, DATASET_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
